import numpy as np
import pandas as pd
from pysankey import sankey
from matplotlib.lines import Line2D
import matplotlib.pylab as plt
# SDG numbers whose 'SDG <n>' labels get a default node colour in
# Visual.make_sankey (the name suggests the environment-related goals).
envSDGs = [6, 7, 11, 12, 13, 14, 15]



class Visual:
    """Plotting and summary helpers for per-report action/SDG initiative counts.

    The methods expect a DataFrame whose last ``N_ACTION_COLS`` columns hold
    initiative counts named ``'<action> - SDG <n>'``. ``NReports`` and
    ``SumStat`` additionally read the ``rfyear`` and ``ISIN`` columns.
    """

    # Number of trailing action/SDG count columns. Presumably 14 actions x 7
    # SDGs (matching the 14-colour palette below and ``envSDGs``) -- confirm
    # against the data source. Override on a subclass for other layouts.
    N_ACTION_COLS = 98

    # Colours for the action (left) nodes, handed out in order of each
    # action's first appearance in the sorted flows table.
    _ACTION_COLORS = ['#EBEEEE', '#9D9D9D', '#002147', '#008AFF', '#00BEFA', '#B5F0E7',
                      '#FFF5BE', '#EC7300', '#D24000', '#E40428', '#e08b00', '#c2b300',
                      '#92d600', '#08CDAE'][::-1]
    # Colour for SDG nodes before they inherit an action colour, and for any
    # action beyond the end of the palette.
    _DEFAULT_COLOR = '#08CDAE'

    def _initiative_columns(self, X):
        """Return the trailing action/SDG count columns of *X*."""
        return X.columns[-self.N_ACTION_COLS:]

    def _flows_frame(self, X):
        """Expand the count columns into one (Action, SDG, weight) row per initiative.

        Each ``'<action> - SDG <n>'`` column contributes ``int(total)`` identical
        rows, where ``total`` is the column sum (also stored as ``weight``).
        ``make_sankey`` passes no weights to ``sankey``, so the flow widths come
        from these repeated rows (assuming pysankey's default of one unit per row).
        """
        counts = X[self._initiative_columns(X)]
        rows = []
        for col in counts.columns[::-1]:
            parts = col.split(' - ')
            action, sdg = parts[0], parts[1]
            total = counts[col].sum()
            rows.extend([[action, sdg, total]] * int(total))

        sk = pd.DataFrame(rows, columns=['Action', 'SDG', 'weight'])
        # Sort on the numeric SDG number so e.g. 'SDG 6' and 'SDG 15' order
        # numerically rather than lexically, then restore the string label.
        sk['SDG'] = sk['SDG'].apply(lambda x: int(x.split('SDG ')[1]))
        sk = sk.sort_values(by=['Action', 'SDG'], ascending=False)
        sk['SDG'] = sk['SDG'].apply(lambda x: 'SDG ' + str(x))
        sk['Action'] = sk['Action'].apply(lambda x: x.capitalize().replace('R&d', 'R&D'))
        return sk

    def _color_dict(self, sk, sdgs):
        """Map every node label in *sk* to a colour.

        Actions take palette colours in order of appearance; when there are
        more actions than palette entries, the extras get ``_DEFAULT_COLOR``.
        Each SDG present in *sk* takes the colour of the action with the
        largest weight towards it.
        """
        palette = self._ACTION_COLORS
        actions = list(sk['Action'].unique())
        # zip-style pairing instead of indexing a fixed-length label list: the
        # old version raised IndexError when fewer than 14 actions had any
        # initiatives (zero-count actions never reach ``sk``).
        col_dict = {action: (palette[i] if i < len(palette) else self._DEFAULT_COLOR)
                    for i, action in enumerate(actions)}
        col_dict.update({'SDG ' + str(s): self._DEFAULT_COLOR for s in sdgs})

        # For each SDG, pick the (Action, SDG) pair with the highest weight.
        max_mapping = (sk.groupby(['Action', 'SDG']).last().reset_index()
                       .sort_values(by=['SDG', 'weight'])
                       .groupby('SDG').last().reset_index()[['SDG', 'Action']])
        for sdg, action in zip(max_mapping['SDG'], max_mapping['Action']):
            col_dict[sdg] = col_dict[action]
        return col_dict

    def make_sankey(self, X, ax=None):
        """Draw a Sankey diagram of initiatives flowing from actions to SDGs.

        Parameters
        ----------
        X : pandas.DataFrame
            Data whose last ``N_ACTION_COLS`` columns hold counts named
            ``'<action> - SDG <n>'``.
        ax : matplotlib Axes, optional
            Axes to draw on. When omitted, a new figure of size (30, 32) is
            created with a single subplot, as before.
        """
        sk = self._flows_frame(X)
        col_dict = self._color_dict(sk, envSDGs)
        if ax is None:
            plt.figure(figsize=(30, 32))
            ax = plt.subplot()
        sankey(left=sk['Action'],
               right=sk['SDG'],
               aspect=140,
               fontsize=40, colorDict=col_dict, ax=ax)

    def NReports(self, Z, ax=None):
        """Plot reports per year (left y-axis) and initiatives per year (right y-axis).

        Parameters
        ----------
        Z : pandas.DataFrame
            Needs ``rfyear`` and ``ISIN`` columns plus the trailing
            ``N_ACTION_COLS`` count columns. It is not modified.
        ax : matplotlib Axes, optional
            Axes for the report counts; defaults to ``plt.subplot()`` on the
            current figure. Initiatives are drawn on a twin of this axes.
        """
        if ax is None:
            ax = plt.subplot()
        # NOTE(review): this counts rows with a non-null ISIN per year, whereas
        # SumStat counts distinct ISINs; they differ if an ISIN appears more
        # than once in a year -- confirm which definition is intended.
        Z[['rfyear', 'ISIN']].groupby('rfyear').count().plot(
            marker='o', lw=2, color='navy', legend=False, ax=ax)
        ax.set_ylabel('Number of reports', color='navy')
        ax.set_xlabel('')

        initiatives = Z[self._initiative_columns(Z)].sum(axis=1)
        ax2 = ax.twinx()
        (Z[['rfyear']].assign(number_of_initiatives=initiatives)
         .groupby('rfyear').sum()
         .plot(marker='o', lw=2, color='darkorange', ax=ax2, legend=False))
        ax2.set_xlabel('')
        ax2.set_ylabel('Number of initiatives', color='darkorange')

    def SumStat(self, Z):
        """Return per-year summary counts.

        Parameters
        ----------
        Z : pandas.DataFrame
            Needs ``rfyear`` and ``ISIN`` columns plus the trailing
            ``N_ACTION_COLS`` count columns. It is not modified.

        Returns
        -------
        pandas.DataFrame
            Indexed by ``rfyear`` with integer columns ``'number of reports'``
            (distinct ISINs in that year) and ``'number of initiatives'``
            (sum of all count columns over that year's rows).
        """
        initiatives = Z[self._initiative_columns(Z)].sum(axis=1)
        reports = Z[['rfyear', 'ISIN']].groupby('rfyear').nunique().astype(int)
        totals = (Z[['rfyear']].assign(number_of_initiatives=initiatives)
                  .groupby('rfyear').sum().astype(int))
        d = pd.concat((reports, totals), axis=1)
        d.columns = ['number of reports', 'number of initiatives']
        return d
        
        